Chapter 2: Small Worlds And Large Worlds¶
[1]:
%load_ext jupyter_black
[2]:
from typing import Sequence
import arviz as az
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import pymc as pm
from scipy import stats
from jax import random as jrandom
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoLaplaceApproximation
seed = 84735
pio.templates.default = "plotly_white"
WARNING:pytensor.tensor.blas:Using NumPy C-API based implementation for BLAS functions.
Code¶
Code 2.1¶
[3]:
# Counts of ways each conjecture could produce the data, normalized
# into relative plausibilities (they sum to 1).
ways = jnp.array([0, 3, 8, 9, 0])
ways / ways.sum()
WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[3]:
Array([0. , 0.15, 0.4 , 0.45, 0. ], dtype=float32)
Code 2.2¶
[4]:
jnp.exp(dist.Binomial(total_count=9, probs=0.5).log_prob(6))
[4]:
Array(0.16406256, dtype=float32)
[5]:
stats.binom.pmf(k=6, n=9, p=[0.5])
[5]:
array([0.1640625])
Code 2.3¶
[6]:
# Data: 6 successes (water) in 9 globe tosses, approximated on a 20-point grid.
n_trials = 9
n_successes = 6
grid_size = 20
# Flat (uniform) prior over the grid; normalized later during plotting.
prior = jnp.full(grid_size, 1)
[7]:
def calculate_grid_approximation_posterior_numpyro(
    n_trials: int,
    n_successes: int,
    prior: Sequence[float],
    grid_size: int,
):
    """Grid approximation of the binomial posterior using numpyro's Binomial.

    Evaluates the likelihood of ``n_successes`` out of ``n_trials`` on an
    evenly spaced grid over [0, 1], multiplies by ``prior``, and normalizes.
    Returns the posterior as an array of length ``grid_size``.
    """
    p_grid = jnp.linspace(0, 1, grid_size)
    log_lik = dist.Binomial(total_count=n_trials, probs=p_grid).log_prob(n_successes)
    unnormalized = prior * jnp.exp(log_lik)
    return unnormalized / unnormalized.sum()
# Posterior over the grid via the numpyro-based implementation.
posterior_numpyro = calculate_grid_approximation_posterior_numpyro(
    n_trials, n_successes, prior, grid_size
)
posterior_numpyro
[7]:
Array([0.0000000e+00, 7.9898348e-07, 4.3077191e-05, 4.0907954e-04,
1.8938888e-03, 5.8738771e-03, 1.4042934e-02, 2.7851753e-02,
4.7801144e-02, 7.2807401e-02, 9.9872954e-02, 1.2426433e-01,
1.4031431e-01, 1.4283489e-01, 1.2894331e-01, 9.9872909e-02,
6.2058900e-02, 2.6454771e-02, 4.6596657e-03, 7.4891537e-20], dtype=float32)
[8]:
def calculate_grid_approximation_posterior_pymc(
    n_trials: int,
    n_successes: int,
    prior: Sequence[float],
    grid_size: int,
):
    """Grid approximation of the binomial posterior using scipy's binom pmf.

    Same contract as the numpyro variant: evaluate the likelihood of
    ``n_successes`` out of ``n_trials`` on a [0, 1] grid, weight by ``prior``,
    and normalize so the result sums to 1.
    """
    p_grid = jnp.linspace(0, 1, grid_size)
    likelihood = stats.binom.pmf(k=n_successes, n=n_trials, p=p_grid)
    unnormalized = prior * likelihood
    return unnormalized / unnormalized.sum()
# Posterior over the grid via the scipy-based implementation; should match
# the numpyro result up to floating-point noise.
posterior_pymc = calculate_grid_approximation_posterior_pymc(
    n_trials, n_successes, prior, grid_size
)
posterior_pymc
[8]:
Array([0.0000000e+00, 7.9898376e-07, 4.3077172e-05, 4.0907960e-04,
1.8938874e-03, 5.8738743e-03, 1.4042936e-02, 2.7851744e-02,
4.7801159e-02, 7.2807401e-02, 9.9872977e-02, 1.2426433e-01,
1.4031433e-01, 1.4283489e-01, 1.2894326e-01, 9.9872947e-02,
6.2058900e-02, 2.6454777e-02, 4.6596681e-03, 0.0000000e+00], dtype=float32)
Code 2.4¶
[9]:
def plot_grid_approximation(prior, posterior, *, title=None, grid_size=20):
    """Plot a prior and posterior over an evenly spaced probability grid.

    Parameters
    ----------
    prior, posterior : array-like of length ``grid_size``
        Curves to plot; each is normalized to sum to 1 before plotting.
    title : str, optional
        Figure title; defaults to a generic caption.
    grid_size : int
        Number of grid points over [0, 1]; must match the array lengths.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure (also displayed via ``fig.show()``).
    """
    grid = jnp.linspace(0, 1, grid_size)
    # Fixed: was an f-string with no placeholders (lint F541).
    title = title or "Grid Approximation of Posterior Distribution"
    # Fixed: `prior /= prior.sum()` mutates caller-owned numpy arrays in place;
    # plain division rebinds locally and leaves the caller's arrays untouched.
    prior = prior / prior.sum()
    posterior = posterior / posterior.sum()
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=grid,
            y=prior,
            name="prior",
            mode="lines+markers",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=grid,
            y=posterior,
            name="posterior",
            mode="lines+markers",
        )
    )
    fig.update_layout(
        title=title,
        xaxis={"title": "p"},
        yaxis={"title": "posterior probability"},
    )
    fig.show()
    return fig
fig = plot_grid_approximation(prior, posterior_pymc)
Code 2.5¶
[10]:
# Step prior: zero probability for p < 0.5, constant above.
grid = jnp.linspace(0, 1, grid_size)
prior = jnp.where(grid < 0.5, 0, 1)
posterior = calculate_grid_approximation_posterior_numpyro(
    n_trials, n_successes, prior, grid_size
)
fig = plot_grid_approximation(prior, posterior)
[11]:
# Peaked, Laplace-shaped prior centered at p = 0.5.
grid = jnp.linspace(0, 1, grid_size)
prior = jnp.exp(-5 * jnp.abs(grid - 0.5))
posterior = calculate_grid_approximation_posterior_pymc(
    n_trials, n_successes, prior, grid_size
)
fig = plot_grid_approximation(prior, posterior)
Code 2.6¶
[12]:
# Observed globe tosses: W = water, L = land; SVI settings for the
# quadratic (Laplace) approximation.
W = 6
L = 3
n_steps = 1_000
n_samples = 1_000
[13]:
def calculate_quadratic_approximation_posterior_numpyro(
    W,
    L,
    n_steps,
    n_samples,
):
    """Quadratic (Laplace) approximation of the globe-tossing posterior via SVI.

    Fits an AutoLaplaceApproximation guide with Adam (learning rate 1) for
    ``n_steps`` steps, then draws ``n_samples`` posterior samples of ``p``.
    Returns the SVI result object and the samples dict.
    """

    def model(W, L):
        p = numpyro.sample("p", dist.Uniform(0, 1))
        numpyro.sample("W", dist.Binomial(total_count=W + L, probs=p), obs=W)

    guide = AutoLaplaceApproximation(model)
    # Independent keys for optimization and posterior sampling.
    train_key, sample_key = jrandom.split(jrandom.PRNGKey(seed))
    svi = SVI(model, guide, optim.Adam(1), Trace_ELBO(), W=W, L=L)
    svi_result = svi.run(train_key, num_steps=n_steps)
    samples = guide.sample_posterior(
        sample_key, params=svi_result.params, sample_shape=(n_samples,)
    )
    return svi_result, samples
# display summary of quadratic approximation
svi_result, samples = calculate_quadratic_approximation_posterior_numpyro(
    W=W, L=L, n_steps=n_steps, n_samples=n_samples
)
# 89% interval matches the book's convention for compatibility intervals.
numpyro.diagnostics.print_summary(samples, prob=0.89, group_by_chain=False)
100%|██████████████████████████| 1000/1000 [00:00<00:00, 2024.51it/s, init loss: 4.4362, avg. loss [951-1000]: 2.7795]
mean std median 5.5% 94.5% n_eff r_hat
p 0.63 0.14 0.65 0.41 0.84 1109.70 1.00
[14]:
def calculate_quadratic_approximation_posterior_pymc(W, L, n_steps, n_samples):
    """PyMC counterpart of the quadratic approximation; intentionally unimplemented."""
    raise NotImplementedError
Code 2.7¶
[15]:
# analytical calculation
W = 6
L = 3
x = jnp.linspace(0, 1, 101)
# Conjugacy: with a flat prior, the binomial posterior is Beta(W + 1, L + 1).
analytical_posterior = jnp.exp(dist.Beta(W + 1, L + 1).log_prob(x))
# Quadratic approximation: a Normal with the SVI samples' mean and std.
quad_posterior = jnp.exp(
    dist.Normal(loc=samples["p"].mean(), scale=samples["p"].std()).log_prob(x)
)
fig = go.Figure(
    [
        go.Scatter(
            x=x,
            y=analytical_posterior,
            name="analytical",
        ),
        go.Scatter(
            x=x,
            y=quad_posterior,
            name="quadratic",
        ),
    ]
)
fig.show()
Code 2.8¶
[16]:
# Metropolis algorithm targeting the globe-tossing posterior (Code 2.8).
n_samples = 1_000
p = [jnp.nan] * n_samples
p[0] = 0.5
W = 6
L = 3
with numpyro.handlers.seed(rng_seed=seed):
    for i in range(1, n_samples):
        # Propose from a Normal centered on the current value.
        p_new = numpyro.sample("p_new", dist.Normal(p[i - 1], 0.1))
        # Reflect proposals back into [0, 1].
        p_new = jnp.abs(p_new) if p_new < 0 else p_new
        p_new = 2 - p_new if p_new > 1 else p_new
        # Binomial likelihood of the data at the current and proposed values.
        q0 = jnp.exp(dist.Binomial(total_count=W + L, probs=p[i - 1]).log_prob(W))
        q1 = jnp.exp(dist.Binomial(total_count=W + L, probs=p_new).log_prob(W))
        # Accept with probability min(1, q1 / q0); otherwise keep the current value.
        u = numpyro.sample("u", dist.Uniform(0, 1))
        p[i] = float(p_new if u < q1 / q0 else p[i - 1])
Code 2.9¶
[17]:
# Density of the Metropolis samples, overlaid with the analytical Beta posterior.
ax = az.plot_density({"p": p}, hdi_prob=1)
plt.plot(x, analytical_posterior, "--")
[17]:
[<matplotlib.lines.Line2D at 0x79cb665a39e0>]
Medium¶
2M1¶
[18]:
def plot_grid_posterior(data, *, grid_size=100):
    """Grid-approximate and plot the posterior of p for a sequence of tosses (2M1).

    Parameters
    ----------
    data : sequence of "W"/"L" strings, one per globe toss.
    grid_size : int
        Number of grid points over [0, 1].

    Returns
    -------
    dict with the normalized ``posterior`` array and the plotly ``fig``.
    """
    # Uniform prior, already normalized to sum to 1.
    prior = jnp.array([1 / grid_size] * grid_size)
    grid = jnp.linspace(0, 1, grid_size)
    n_success = jnp.sum(jnp.array([x == "W" for x in data]))
    likelihood = jnp.exp(
        dist.Binomial(total_count=len(data), probs=grid).log_prob(n_success)
    )
    posterior = prior * likelihood
    posterior /= jnp.sum(posterior)
    fig = go.Figure(
        [
            go.Scatter(
                x=grid,
                y=prior,
                name="prior",
            ),
            go.Scatter(
                x=grid,
                y=posterior,
                # Fixed legend-label typo: was "posteriror".
                name="posterior",
            ),
        ]
    )
    fig.update_layout(
        title=",".join(data),
        xaxis={"title": "p"},
        yaxis={"title": "probability"},
    )
    return {"posterior": posterior, "fig": fig}
# Posterior after each observation sequence from exercise 2M1.
for data in [
    ["W", "W", "W"],
    ["W", "W", "W", "L"],
    ["L", "W", "W", "L", "W", "W", "W"],
]:
    results = plot_grid_posterior(data)
    results["fig"].show()
2M2¶
[19]:
def plot_grid_posterior(data, *, grid_size=100):
    """Grid-approximate and plot the posterior of p with a step prior (2M2).

    The prior is zero for p < 0.5 and constant above.

    Parameters
    ----------
    data : sequence of "W"/"L" strings, one per globe toss.
    grid_size : int
        Number of grid points over [0, 1].

    Returns
    -------
    dict with the normalized ``posterior`` array and the plotly ``fig``.
    """
    grid = jnp.linspace(0, 1, grid_size)
    # Vectorized step prior (the original iterated over jax scalars in a
    # Python list comprehension; jnp.where produces the same int array).
    prior = jnp.where(grid < 0.5, 0, 1)
    prior /= prior.sum()
    n_success = jnp.sum(jnp.array([x == "W" for x in data]))
    likelihood = jnp.exp(
        dist.Binomial(total_count=len(data), probs=grid).log_prob(n_success)
    )
    posterior = prior * likelihood
    posterior /= jnp.sum(posterior)
    fig = go.Figure(
        [
            go.Scatter(
                x=grid,
                y=prior,
                name="prior",
            ),
            go.Scatter(
                x=grid,
                y=posterior,
                # Fixed legend-label typo: was "posteriror".
                name="posterior",
            ),
        ]
    )
    fig.update_layout(
        title=",".join(data),
        xaxis={"title": "p"},
        yaxis={"title": "probability"},
    )
    return {"posterior": posterior, "fig": fig}
# Same observation sequences as 2M1, now under the truncated prior.
for data in [
    ["W", "W", "W"],
    ["W", "W", "W", "L"],
    ["L", "W", "W", "L", "W", "W", "W"],
]:
    results = plot_grid_posterior(data)
    results["fig"].show()
2M3¶
2M4¶
card 1 has 2 ways of showing observed data
card 2 has 1 way of showing observed data
card 3 has 0 ways of showing observed data
Of the 3 possible ways of showing the observed data, only 2 have the other side also black, hence a 2/3 chance the other side is black.
2M5¶
Now 5 ways of seeing observed data and of those 5, 4 ways that the other side is also black. p=4/5
2M6¶
card 1 has 2 ways of yielding the observed data
card 2 has 2 ways of yielding the observed data
card 3 has 0 ways of yielding the observed data
Of the 4 possible ways of getting the observed data, only 2 have the other side also black. p=0.5
2M7¶
| cards (black-up, other) | count |
|---|---|
| 1,2 | 2 |
| 1,3 | 4 |
| 2,1 | 0 |
| 2,3 | 2 |
| 3,1 | 0 |
| 3,2 | 0 |
p = 6 / 8
Hard¶
2H1¶
Notation:
First we calculate the posterior probability of panda being of species A:
Armed with that, we calculate the probability of a second twin birth (assuming independence between births):
[20]:
p_twins = [0.1, 0.2]
n_samples = 1_000
second_birth_is_twins = [jnp.nan] * n_samples
with numpyro.handlers.seed(rng_seed=seed):
i = 0
while i < n_samples:
species = numpyro.sample("u", dist.Categorical(jnp.array([0.5, 0.5])))
first_birth_is_twins = numpyro.sample(
"first_birth_is_twins", dist.Bernoulli(probs=p_twins[species])
)
if not first_birth_is_twins:
continue
second_birth_is_twins[i] = numpyro.sample(
"second_birth_is_twins", dist.Bernoulli(probs=p_twins[species])
)
i += 1
print(f"P(T, T| T) = {jnp.array(second_birth_is_twins).mean():.2f}.")
P(T, T| T) = 0.17.
[21]:
p_twins = [0.1, 0.2]
n_samples = 5_000
rng = np.random.default_rng(seed=seed)
second_birth_is_twins = [jnp.nan] * n_samples
i = 0
while i < n_samples:
species = rng.choice([0, 1])
first_birth_is_twins = stats.bernoulli.rvs(p=p_twins[species], random_state=rng)
if not first_birth_is_twins:
continue
second_birth_is_twins[i] = stats.bernoulli.rvs(p=p_twins[species], random_state=rng)
i += 1
print(f"P(T, T| T) = {jnp.array(second_birth_is_twins).mean():.2f}.")
P(T, T| T) = 0.17.
2H2¶
As per above \(P(A) = 1/3\)
[22]:
p_twins = [0.1, 0.2]
n_samples = 1_000
is_species_a = [jnp.nan] * n_samples
with numpyro.handlers.seed(rng_seed=seed):
i = 0
while i < n_samples:
species = numpyro.sample("u", dist.Categorical(jnp.array([0.5, 0.5])))
first_birth_is_twins = numpyro.sample(
"first_birth_is_twins", dist.Bernoulli(probs=p_twins[species])
)
if not first_birth_is_twins:
continue
is_species_a[i] = 1 - species
i += 1
print(f"P(A) = {jnp.array(is_species_a).mean():.2f}.")
P(A) = 0.33.
[23]:
p_twins = [0.1, 0.2]
n_samples = 2_500
rng = np.random.default_rng(seed=seed)
is_species_a = [jnp.nan] * n_samples
i = 0
j = 0
while i < n_samples:
j += 1
species = rng.choice([0, 1])
first_birth_is_twins = stats.bernoulli.rvs(p=p_twins[species], random_state=rng)
if not first_birth_is_twins:
continue
is_species_a[i] = species == 0
i += 1
print(f"P(A) = {jnp.array(is_species_a).mean():.2f}.")
P(A) = 0.33.
2H3¶
[24]:
p_twins = [0.1, 0.2]
n_samples = 1_000
is_species_a = [jnp.nan] * n_samples
with numpyro.handlers.seed(rng_seed=seed):
i = 0
while i < n_samples:
species = numpyro.sample("u", dist.Categorical(jnp.array([0.5, 0.5])))
first_birth_is_twins = numpyro.sample(
"first_birth_is_twins", dist.Bernoulli(probs=p_twins[species])
)
if not first_birth_is_twins:
continue
second_birth_is_twins = numpyro.sample(
"second_birth_is_twins", dist.Bernoulli(probs=p_twins[species])
)
if second_birth_is_twins:
continue
is_species_a[i] = 1 - species
i += 1
print(f"P(A | T') = {jnp.array(is_species_a).mean():.2f}")
P(A | T') = 0.36
[25]:
p_twins = [0.1, 0.2]
n_samples = 1_000
rng = np.random.default_rng(seed=seed)
is_species_a = [jnp.nan] * n_samples
i = 0
while i < n_samples:
species = rng.choice([0, 1])
first_birth_is_twins = stats.bernoulli.rvs(p=p_twins[species], random_state=rng)
if not first_birth_is_twins:
continue
second_birth_is_twins = stats.bernoulli.rvs(p=p_twins[species], random_state=rng)
if second_birth_is_twins:
continue
is_species_a[i] = species == 0
i += 1
print(f"P(A | T') = {jnp.array(is_species_a).mean():.2f}")
P(A | T') = 0.37
2H4¶
Ignoring the birth data:
With the birth data, our prior is now \(P(A) = 0.36\):
[26]:
no_births_analytical = (0.8 * 0.5) / (0.8 * 0.5 + (1 - 0.65) * 0.5)
with_births_analytical = (0.8 * 0.36) / (0.8 * 0.36 + (1 - 0.65) * 0.64)
[27]:
p_twins = [0.1, 0.2]
p_test_says_a = [0.8, 1 - 0.65]
n_samples = 1_000
is_species_a_no_births = []
is_species_a_with_births = []
with numpyro.handlers.seed(rng_seed=seed):
i_no_births = 0
i_with_births = 0
while min(i_no_births, i_with_births) < n_samples:
species = numpyro.sample("u", dist.Categorical(jnp.array([0.5, 0.5])))
test_says_a = numpyro.sample(
"test_says_a", dist.Bernoulli(probs=p_test_says_a[species])
)
if test_says_a:
is_species_a_no_births.append(1 - species)
i_no_births += 1
else:
continue
total_twin_births = numpyro.sample(
"total_twin_births", dist.Binomial(total_count=2, probs=p_twins[species])
)
if total_twin_births != 1:
continue
is_species_a_with_births.append(1 - species)
i_with_births += 1
no_births_mean = jnp.array(is_species_a_no_births).mean()
no_births_z = jnp.array(is_species_a_no_births).std() * jnp.sqrt(n_samples)
no_births_t = abs(no_births_mean - no_births_analytical) / no_births_z
with_births_mean = jnp.array(is_species_a_with_births).mean()
with_births_z = jnp.array(is_species_a_with_births).std() * jnp.sqrt(n_samples)
with_births_t = abs(with_births_mean - with_births_analytical) / with_births_z
print(
f"P(A|A_hat) = {no_births_mean:.4f} (t-stat: {no_births_t:.4f})\n"
f"P(A|T, T', A_hat) = {with_births_mean:.4f} (t-stat {with_births_t:.4f})"
)
P(A|A_hat) = 0.7064 (t-stat: 0.0007)
P(A|T, T', A_hat) = 0.5680 (t-stat 0.0004)
[28]:
p_twins = [0.1, 0.2]
p_test_says_a = [0.8, 1 - 0.65]
n_samples = 2_000
rng = np.random.default_rng(seed=seed)
is_species_a_no_births = []
is_species_a_with_births = []
i_no_births = 0
i_with_births = 0
while min(i_no_births, i_with_births) < n_samples:
species = rng.choice([0, 1])
test_says_a = stats.bernoulli.rvs(p=p_test_says_a[species])
if test_says_a:
is_species_a_no_births.append(1 - species)
i_no_births += 1
else:
continue
first_birth_is_twins = stats.bernoulli.rvs(p=p_twins[species])
if not first_birth_is_twins:
continue
second_birth_is_twins = stats.bernoulli.rvs(p=p_twins[species])
if second_birth_is_twins:
continue
is_species_a_with_births.append(1 - species)
i_with_births += 1
no_births_mean = jnp.array(is_species_a_no_births).mean()
no_births_z = jnp.array(is_species_a_no_births).std() * jnp.sqrt(n_samples)
no_births_t = abs(no_births_mean - no_births_analytical) / no_births_z
with_births_mean = jnp.array(is_species_a_with_births).mean()
with_births_z = jnp.array(is_species_a_with_births).std() * jnp.sqrt(n_samples)
with_births_t = abs(with_births_mean - with_births_analytical) / with_births_z
print(
f"P(A|A_hat) = {no_births_mean:.4f} (t-stat: {no_births_t:.4f})\n"
f"P(A|T, T', A_hat) = {with_births_mean:.4f} (t-stat {with_births_t:.4f})"
)
P(A|A_hat) = 0.6936 (t-stat: 0.0001)
P(A|T, T', A_hat) = 0.5510 (t-stat 0.0005)
[ ]: